{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layer-wise learning rate settings\n",
    "\n",
    "In this tutorial, we show how to select or filter network layers and assign each group its own learning rate, which is useful for transfer learning.  \n",
    "MONAI provides the utility function `generate_param_groups` for this purpose, for example:\n",
    "```py\n",
    "net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])\n",
    "print(net)  # print out network components to select expected items\n",
    "print([name for name, _ in net.named_parameters()])  # print out all the parameter names to filter out expected items\n",
    "params = generate_param_groups(\n",
    "    network=net,\n",
    "    layer_matches=[lambda x: x.model[0], lambda x: \"2.0.conv\" in x[0]],\n",
    "    match_types=[\"select\", \"filter\"],\n",
    "    lr_values=[1e-2, 1e-3],\n",
    ")\n",
    "optimizer = torch.optim.Adam(params, 1e-4)\n",
    "```\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/layer_wise_learning_rate.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, ignite, tqdm]\"\n",
    "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from monai.transforms import (\n",
    "    EnsureChannelFirstd,\n",
    "    Compose,\n",
    "    LoadImaged,\n",
    "    ScaleIntensityd,\n",
    "    EnsureTyped,\n",
    ")\n",
    "from monai.optimizers import generate_param_groups\n",
    "from monai.networks.nets import DenseNet121\n",
    "from monai.inferers import SimpleInferer\n",
    "from monai.handlers import StatsHandler, from_engine\n",
    "from monai.engines import SupervisedTrainer\n",
    "from monai.data import DataLoader\n",
    "from monai.config import print_config\n",
    "from monai.apps import MedNISTDataset\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from ignite.engine import Engine, Events\n",
    "from ignite.metrics import Accuracy\n",
    "import tempfile\n",
    "import sys\n",
    "import shutil\n",
    "import os\n",
    "import logging\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create training experiment with MedNISTDataset and workflow\n",
    "\n",
    "The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions), [the RSNA Bone Age Challenge](https://www.rsna.org/education/ai-resources-and-training/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017), and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).\n",
    "\n",
    "If you use the MedNIST dataset, please acknowledge the source."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up pre-processing transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=\"image\"),\n",
    "        EnsureChannelFirstd(keys=\"image\"),\n",
    "        ScaleIntensityd(keys=\"image\"),\n",
    "        EnsureTyped(keys=\"image\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create MedNISTDataset for training\n",
    "`MedNISTDataset` inherits from MONAI `CacheDataset`. It provides parameters to automatically download and extract the dataset, and it acts as a normal PyTorch `Dataset` with a caching mechanism."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_ds = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"training\", download=True)\n",
    "# the dataset works seamlessly with the native PyTorch DataLoader,\n",
    "# but monai.data.DataLoader adds multi-process random seed handling\n",
    "# and customized collate functions\n",
    "train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pick images from MedNISTDataset to visualize and check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.subplots(3, 3, figsize=(8, 8))\n",
    "for i in range(9):\n",
    "    plt.subplot(3, 3, i + 1)\n",
    "    plt.imshow(train_ds[i * 5000][\"image\"][0].detach().cpu(), cmap=\"gray\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create training components - device, network, loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = DenseNet121(pretrained=True, progress=False, spatial_dims=2, in_channels=1, out_channels=6).to(device)\n",
    "loss = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set different learning rate values for layers\n",
    "Please refer to the appendix at the end of this notebook for the layers of `DenseNet121`.\n",
    "1. Set LR=1e-3 for the selected `class_layers` block.\n",
    "2. Set LR=1e-4 for the convolution layers, using a filter that matches parameter names containing `conv.weight`.\n",
    "3. Use LR=1e-5, the default passed to the optimizer, for all other layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = generate_param_groups(\n",
    "    network=net,\n",
    "    layer_matches=[lambda x: x.class_layers, lambda x: \"conv.weight\" in x[0]],\n",
    "    match_types=[\"select\", \"filter\"],\n",
    "    lr_values=[1e-3, 1e-4],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the optimizer based on the parameter groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = torch.optim.Adam(params, 1e-5)"
   ]
  },
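  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the grouping concrete, the snippet below sketches in plain PyTorch roughly what the two match types amount to. It is a simplified illustration under stated assumptions, not MONAI's implementation: the toy `nn.Sequential` network and the names `toy_net`, `head`, `conv_w`, `others`, and `toy_opt` are hypothetical.\n",
    "\n",
    "```py\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# hypothetical toy network, used only to show how parameter groups are formed\n",
    "toy_net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Flatten(), nn.Linear(4, 2))\n",
    "\n",
    "# 'select': take every parameter of a chosen sub-module\n",
    "head = list(toy_net[3].parameters())\n",
    "# 'filter': keep the parameters whose name satisfies a predicate\n",
    "conv_w = [p for name, p in toy_net.named_parameters() if name == '0.weight']\n",
    "# everything not matched above goes into a final group without its own LR\n",
    "used = {id(p) for p in head + conv_w}\n",
    "others = [p for p in toy_net.parameters() if id(p) not in used]\n",
    "\n",
    "groups = [{'params': head, 'lr': 1e-3}, {'params': conv_w, 'lr': 1e-4}, {'params': others}]\n",
    "toy_opt = torch.optim.Adam(groups, lr=1e-5)\n",
    "print([g['lr'] for g in toy_opt.param_groups])\n",
    "```\n",
    "\n",
    "A group without an explicit `lr` falls back to the default passed to the optimizer, which is why the last group trains at 1e-5. PyTorch rejects a parameter that appears in more than one group, so the matches should not overlap."
   ]
  },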
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a minimal training workflow\n",
    "\n",
    "Use the MONAI `SupervisedTrainer` with handlers to quickly set up a training workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": [
     "outputPrepend"
    ]
   },
   "outputs": [],
   "source": [
    "trainer = SupervisedTrainer(\n",
    "    device=device,\n",
    "    max_epochs=5,\n",
    "    train_data_loader=train_loader,\n",
    "    network=net,\n",
    "    optimizer=opt,\n",
    "    loss_function=loss,\n",
    "    inferer=SimpleInferer(),\n",
    "    key_train_metric={\"train_acc\": Accuracy(output_transform=from_engine([\"pred\", \"label\"]))},\n",
    "    train_handlers=StatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define an ignite handler to adjust LR at runtime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LrScheduler:\n",
    "    def attach(self, engine: Engine) -> None:\n",
    "        engine.add_event_handler(Events.EPOCH_COMPLETED, self)\n",
    "\n",
    "    def __call__(self, engine: Engine) -> None:\n",
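    "        # group 0: selected `class_layers`, group 1: filtered conv weights,\n",
    "        # group 2: all remaining parameters (LR left unchanged)\n",
    "        for i, param_group in enumerate(engine.optimizer.param_groups):\n",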
    "            if i == 0:\n",
    "                param_group[\"lr\"] *= 0.1\n",
    "            elif i == 1:\n",
    "                param_group[\"lr\"] *= 0.5\n",
    "\n",
    "        print(\"LR values of 3 parameter groups: \", [g[\"lr\"] for g in engine.optimizer.param_groups])\n",
    "\n",
    "\n",
    "LrScheduler().attach(trainer)"
   ]
  },
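  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the handler above, the schedule can be reproduced with plain arithmetic. This is a minimal sketch that assumes the initial values used in this notebook (1e-3, 1e-4, and the optimizer default 1e-5) and `max_epochs=5`; the names `initial_lrs`, `factors`, and `lrs_after` are illustrative only.\n",
    "\n",
    "```py\n",
    "initial_lrs = [1e-3, 1e-4, 1e-5]  # class_layers, conv weights, all other parameters\n",
    "factors = [0.1, 0.5, 1.0]  # per-epoch multipliers; the last group is left unchanged\n",
    "max_epochs = 5\n",
    "\n",
    "\n",
    "def lrs_after(epochs):\n",
    "    # LR of each group after `epochs` completed epochs, up to floating-point rounding\n",
    "    return [lr * f**epochs for lr, f in zip(initial_lrs, factors)]\n",
    "\n",
    "\n",
    "for epoch in range(1, max_epochs + 1):\n",
    "    print(epoch, lrs_after(epoch))\n",
    "```\n",
    "\n",
    "After the fifth epoch the first group has dropped to about 1e-8 and the second to about 3.1e-6, while the third stays at 1e-5."
   ]
  },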
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Execute the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove the directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Appendix: layers of the DenseNet121 network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(net)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
